%% Jaravel and Lashkari algorithm %%%%%%%%%%%%%%
%  All arrays are indexed in (I,N,T) order.
%  Polynomial approximations use polyfit with centering and scaling to
%  mitigate multicollinearity (ill-conditioning) problems.
function [LJMM_F,LJMM_SC,JLerrorType] = CalLJ_poly(I_vec, B_vec, pvec,paras)
% CalLJ_poly  Jaravel-Lashkari first- and second-order algorithms using
% centered/scaled polynomial (polyfit) approximations.
%
% Inputs
%   I_vec : values whose period-to-period ratios I_vec(:,:,t+1)./I_vec(:,:,t)
%           drive the update; its first slice initializes log q
%           (assumed 1 x N x T -- TODO confirm against callers).
%   B_vec : I x N x T array; its size sets paras.I, paras.N, paras.T.
%   pvec  : array combined elementwise with B_vec by calP.
%   paras : struct providing Kn, the polynomial degree. Fields T, N and I
%           are overwritten from size(B_vec).
%
% Outputs
%   LJMM_F      : 1 x N x T, exp of the first-order log q path.
%   LJMM_SC     : 1 x N x T, second-order path; each period t+1 is iterated
%                 (at most maxiter times) until the max change in log q
%                 is no larger than tol.
%   JLerrorType : [flag_FOerror, flag_SOerror, flagSOnonC]
%                 flag_FOerror - a warning was raised in the first-order stage
%                 flag_SOerror - a warning was raised in the second-order stage
%                 flagSOnonC   - some period's second-order iteration did not
%                                converge (hit maxiter above tol, or the
%                                change became Inf or NaN)

T = size(B_vec,3);
N = size(B_vec,2);
I = size(B_vec,1);

flag_FOerror = 0;
flag_SOerror = 0;
flagSOnonC = 0;

paras.T = T;paras.N = N;paras.I = I;

% Reset lastwarn so the checks below only see warnings raised during this
% call (e.g. polyfit's badly-conditioned warning), not leftovers from before.
lastwarn('');

%% First order algorithm:
Iratio = I_vec(:,:,2:end)./I_vec(:,:,1:end-1);
[pit_n_geo,pt_n_torn]= calP(B_vec,pvec,paras);
% initialization
alp_poly = zeros(paras.Kn+1,1,paras.T) ; % polynomial coefficients (ascending order)
mu_vec   = zeros(2,1,paras.T) ; % polyfit centering/scaling [mean; std]
Lambda_vec = zeros(paras.N,paras.T);
logq_vec = zeros(1,paras.N,paras.T);
logq_vec(1,:,1) = log(I_vec(:,:,1)); % initialize by nominal income

for t = 1:paras.T-1
    % Step 1: approximate log(pit_n_geo) as a polynomial in log q
    q = exp(logq_vec(:,:,t));
    Y = log(pit_n_geo(:,:,t))';
    X = log(q)';
    [p,~,mu] = polyfit(X,Y, paras.Kn);
    alp_poly(:,1,t) = flip(p);
    mu_vec(:,1,t) = mu;
    % Step 2: update, summing the derivatives of every polynomial fitted so far
    for tt = 1:t
        Lambda_vec(:,tt) = NH_Corr(q,paras,paras.N,mu_vec(:,1,tt),alp_poly(:,1,tt));
    end
    Lambda_t1 = sum(Lambda_vec,2)';
    logq_vec(:,:,t+1) = logq_vec(:,:,t) + 1./(1+Lambda_t1).*...
        log(Iratio(:,:,t)./pit_n_geo(:,:,t));
end
LJMM_F = exp(logq_vec);

[~, warn_id] = lastwarn;
if isempty(warn_id) == 0
    flag_FOerror = 1;
end

lastwarn('');
%% Second order algorithm
% initial value: q path from the first-order algorithm
maxiter = 100;
tol  = 10^(-8);

q_vec_tau = LJMM_F;
q_vec_tau_old = LJMM_F;
bet_poly = zeros(paras.Kn+1,1,paras.T) ; % polynomial coefficients (ascending order)
mu_vec   = zeros(2,1,paras.T) ; % polyfit centering/scaling [mean; std]

for t = 1:paras.T-1
    dif = 10;
    iter = 0;
    while iter < maxiter && dif > tol
        % Step (a): fit log(pit_n_geo) on log q_{t+1}, then build rho
        q_t1 = q_vec_tau(:,:,t+1);
        q_t = q_vec_tau(:,:,t);
        Y2a = log(pit_n_geo(:,:,t))';
        X2a = log(q_t1)';
        [p,~,mu] = polyfit(X2a,Y2a, paras.Kn);
        alp_polyS = flip(p);
        rho = Cal_rho(q_t1,q_t,paras,N,mu,alp_polyS);
        % Step (b): fit rho-adjusted log(pt_n_torn) on log q_t
        Y2b = log(pt_n_torn(:,:,t))' + rho';
        X2b = log(q_t)';
        [p,~,mu] = polyfit(X2b,Y2b, paras.Kn);
        bet_poly(:,1,t) = flip(p);
        mu_vec(:,1,t) = mu;

        Lambda_vec1 = zeros(N,T);
        Lambda_vec2 = zeros(N,T);
        for tt = 1:t
            Lambda_vec1(:,tt) = NH_Corr(q_t1,paras,paras.N,mu_vec(:,1,tt),bet_poly(:,1,tt));
        end
        for tt = 1:t-1   % empty when t == 1
            Lambda_vec2(:,tt) = NH_Corr(q_t,paras,paras.N,mu_vec(:,1,tt),bet_poly(:,1,tt));
        end
        Lambda2S = 1/2*(sum(Lambda_vec1,2)' + sum(Lambda_vec2,2)');

        % Step (c): update period t+1
        q_vec_tau(:,:,t+1) = exp( log(q_vec_tau(:,:,t)) + 1./(1+Lambda2S).*...
            log(Iratio(:,:,t)./pt_n_torn(:,:,t)));

        dif = max(abs(log(q_vec_tau_old(:,:,t+1)) - log(q_vec_tau(:,:,t+1))),[],'includenan');
        if dif == inf
            flagSOnonC = 1 ;
        end

        q_vec_tau_old = q_vec_tau;
        iter = iter + 1;
    end
    % Flag non-convergence when the loop stopped without reaching tol: either
    % maxiter was exhausted, or dif became NaN (NaN > tol is false, which
    % would otherwise end the loop silently as if it had converged).
    if ~(dif <= tol)
        flagSOnonC = 1 ;
    end
end

LJMM_SC = q_vec_tau;

[~, warn_id] = lastwarn;
if isempty(warn_id) == 0
    flag_SOerror = 1;
end

JLerrorType = [flag_FOerror,flag_SOerror,flagSOnonC];
end

%% Functions

function [pit_n_geo,pt_n_torn]= calP(B_vec,pvec,paras)
% calP  Per-period price indices for each column n, for t = 1..T-1:
%   pit_n_geo(1,n,t) = prod_i (p(i,n,t+1)/p(i,n,t))^B(i,n,t)
%   pt_n_torn(1,n,t) = prod_i (p(i,n,t+1)/p(i,n,t))^((B(i,n,t)+B(i,n,t+1))/2)
% where p = pvec and B = B_vec. Slice T of both outputs is left at 1.
% Both outputs are 1 x paras.N x paras.T.
% The product is taken explicitly along dimension 1. Default prod() reduces
% along the first non-singleton dimension, so with a single row (I == 1) it
% would collapse the N columns into one scalar.
pit_n_geo = ones(1,paras.N,paras.T);
pt_n_torn = ones(1,paras.N,paras.T);
for tt = 1:paras.T-1
    ratio = pvec(:,:,tt+1)./pvec(:,:,tt);
    pit_n_geo(:,:,tt) = prod(ratio.^B_vec(:,:,tt), 1);
    pt_n_torn(:,:,tt) = prod(ratio.^((B_vec(:,:,tt)+B_vec(:,:,tt+1))/2), 1);
end

end

function lambq = NH_Corr(q,paras,Nsize,mu,coef_poly)
% NH_Corr  Derivative, with respect to log(q), of a polynomial fitted by
% polyfit's centered/scaled form.
%   coef_poly : ascending coefficients a_0..a_Kn (polyfit output after flip)
%               of a polynomial in (log(q) - mu(1))/mu(2).
%   mu        : [mu(1); mu(2)] from [P,S,MU] = polyfit(X,Y,N), where
%               XHAT = (X-MU(1))/MU(2), MU(1) = mean(X), MU(2) = std(X).
% Returns the 1 x Nsize row
%   sum_{k=1}^{Kn} k * a_k / mu(2)^k * (log(q) - mu(1)).^(k-1),
% so a_0 contributes nothing.
centered = log(q(:)).' - mu(1);
lambq = zeros(1, Nsize);
for deg = 1:paras.Kn
    scale = deg*coef_poly(deg+1).*1/mu(2)^(deg);
    lambq = lambq + scale*centered.^(deg-1);
end
end

function rho = Cal_rho(q,q1,paras,Nsize,mu,coef_poly)
% Cal_rho  Second-order correction term
%   rho = 1/4 * (g'(log q) + g'(log q1)) .* log(q./q1),
% where g' is the derivative, with respect to log(q), of the polynomial with
% ascending coefficients coef_poly = a_0..a_Kn in the polyfit-scaled variable
% (log(q) - mu(1))/mu(2). Returns a 1 x Nsize row.
rhok = zeros(paras.Kn+1,Nsize);
rho1k = zeros(paras.Kn+1,Nsize);
for k = 1:paras.Kn   % the k = 0 term has zero derivative
    scale = k*coef_poly(k+1)/mu(2)^k;
    rhok(k+1,:)  = scale*(log(q) -mu(1)).^(k-1);
    rho1k(k+1,:) = scale*(log(q1)-mu(1)).^(k-1);
end
% Elementwise ratio. q/q1 (mrdivide) on two row vectors would return a single
% least-squares scalar instead of the per-column log(q_n/q1_n).
rho = 1/4.*sum(rhok+rho1k,1).*log(q./q1);
end